home *** CD-ROM | disk | FTP | other *** search
/ Turnbull China Bikeride / Turnbull China Bikeride - Disc 1.iso / ARGONET / PD / MATHS / RLAB / RLAB125.ZIP / !RLaB / examples / nn-test < prev    next >
Text File  |  1995-05-20  |  5KB  |  190 lines

  1. //
  2. // @(#) mail.r
  3. //
  4. // ([{
  5. //
  6.  
  7. // necessarily global parameters:
  8. //   fircoeff
  9. //   nfir
  10. //   nsaved
  11. //   sigsave
  12. //   firout
  13. //   desired
  14. //   iter
  15.  
  16. // initialize model filter
  17. initfir = function () 
  18. {
  19.   global(niter, fircoeff, nfir, nsaved, sigsave, firout, ...
  20.          desired, iter, ninput, input, weight1, theta1, ...
  21.          weight2, theta2, ndf, nhidden, mu, esquare, output, ...
  22.          nwsave, wsave, skiplimit)
  23.  
  24.   // model filter parameters
  25.   // fircoeff = [0.2;-0.4;1;-0.6;0.4];    // model filter coefficients
  26.   fircoeff = [1];
  27.   nfir = size(fircoeff)[1];
  28.   nsaved = 20;                    // signal delay line length, must
  29.   // be >= nfir
  30.   rand("default");
  31.   sigsave = rand(nsaved, 1);        // initialize
  32.   
  33. }
  34.  
  35. // model filter time step
  36. stepfir = function () 
  37. {
  38.   global(niter, fircoeff, nfir, nsaved, sigsave, firout, ...
  39.          desired, iter, ninput, input, weight1, theta1, ...
  40.          weight2, theta2, ndf, nhidden, mu, esquare, output, ...
  41.          nwsave, wsave, skiplimit)
  42.   
  43.   // excitation signal generation
  44.   rand("default");
  45.   //    signal = (rand() > 0.5) * 0.8 + 0.1; // exp
  46.   //    signal = (rand() > 0.5) * 1.8 - 0.9; // tanh
  47.   signal = (rand() > 0.5) * 1.8 - 0.9; // simple
  48.   
  49.   // filter update
  50.   sigsave = [signal; sigsave[1:(nsaved - 1)]];
  51.   firout = sigsave[1:nfir]' * fircoeff;
  52.   
  53.   // nonlinearity
  54.   // firout = tanh(3 * firout);
  55.   // firout = 2 * sin(3 * firout) + firout;
  56.   
  57.   // noise
  58.   rand("normal", 0, 0.1);
  59.   firout = firout + rand();
  60.   
  61.   // impulse noise
  62.   rand("default");
  63.   if (rand() > 0.7) 
  64.   {
  65.     firout = firout + 5;
  66.   }
  67.   
  68.   // what we would like to adapt to
  69.   // desired = firout;
  70.   desired = sigsave[3;];
  71.   
  72. }
  73.  
  74. bp2init = function () 
  75. {
  76.   global(niter, fircoeff, nfir, nsaved, sigsave, firout, ...
  77.          desired, iter, ninput, input, weight1, theta1, ...
  78.          weight2, theta2, ndf, nhidden, mu, esquare, output, ...
  79.          nwsave, wsave, skiplimit)
  80.  
  81.   // backprop parameters
  82.   ninput = 10;            // input vector length
  83.   input = zeros(ninput, 1);    // initialize
  84.   ndf = 0;            // number of feedback taps
  85.   nhidden = 10;
  86.   rand("default");
  87.   weight1 = (2 * rand(nhidden, ninput ) - 1) * 0.1;
  88.   theta1  = (2 * rand(nhidden, 1      ) - 1) * 0.1;
  89.   weight2 = (2 * rand(1      , nhidden) - 1) * 0.1;
  90.   theta2  = (2 * rand(1      , 1      ) - 1) * 0.1;
  91.   mu = 1.0;
  92.  
  93.   // miscellanea
  94.   niter = 20000;
  95.   skiplimit = 0;
  96.   nwsave = niter - skiplimit;
  97.   wsave = zeros(nwsave, 5);
  98.   esquare = 0;
  99.   output = 0;        // so that we can shift the first value into input[]
  100.   
  101.   initfir();
  102.   
  103. }
  104.  
  105. bp2steps = function (fromiter, toiter, skiplimit) 
  106. {
  107.   global(niter, fircoeff, nfir, nsaved, sigsave, firout, ...
  108.          desired, iter, ninput, input, weight1, theta1, ...
  109.          weight2, theta2, ndf, nhidden, mu, esquare, output, ...
  110.          nwsave, wsave, skiplimit)
  111.  
  112.   for (iter in fromiter:toiter) 
  113.   {
  114.     
  115.     mu = 0.01 * (niter - iter + 1) / niter;
  116.     
  117.     stepfir();
  118.     
  119.     // reconstrunction filter data generation
  120.     input = [firout; input[1:(ninput - 1)]];
  121.     
  122.     // forward pass
  123.     s1 = weight1 * input + theta1;
  124.     hidden = s1 ./ (1 + abs(s1)); // simple
  125.     s2 = weight2 * hidden + theta2;
  126.     output = s2 ./ (1 + abs(s2)); // simple
  127.     
  128.     // error vectors
  129.     e2 = (mu * (1 ./ ((1 + abs(s2)) .* (1 + abs(s2))) + 0.05)) ...
  130.           .* (desired - output); // simple
  131.     e1 = (1 ./ ((1 + abs(s1)) .* (1 + abs(s1))) + 0.05) ...
  132.           .* (weight2' * e2); // simple
  133.  
  134.     // weight update
  135.     weight1 = weight1 + e1 * input';
  136.     weight2 = weight2 + e2 * hidden';
  137.     
  138.     // threshold update
  139.     theta1 = theta1 + e1;
  140.     theta2 = theta2 + e2;
  141.     
  142.     err = desired - output;
  143.     esquare = 0.995 * esquare + 0.005 * (err' * err);
  144.     if (iter > skiplimit) 
  145.     {
  146.       wsave[(iter - skiplimit);] = [desired, firout, output, err, esquare];
  147.     }
  148.     
  149.   }
  150.   
  151. }
  152.  
  153. bp2 = function () 
  154. {
  155.   global(niter, fircoeff, nfir, nsaved, sigsave, firout, ...
  156.          desired, iter, ninput, input, weight1, theta1, ...
  157.          weight2, theta2, ndf, nhidden, mu, esquare, output, ...
  158.          nwsave, wsave, skiplimit)
  159.  
  160.   iter = 1;
  161.   toiter = 0;
  162.   while (iter < niter) 
  163.   {
  164.     fromiter = toiter + 1;
  165.     toiter = fromiter + 999;
  166.     if (toiter > niter) 
  167.     {
  168.       toiter = niter;
  169.     }
  170.     
  171.     bp2steps(fromiter, toiter, skiplimit);
  172.     
  173.     printf("Iteration %d: esquare = %g,\n", iter, esquare);
  174.     printf("   weight1^2 = %g, theta1^2 = %g,\n", ...
  175.                trace(weight1 * weight1'), theta1' * theta1);
  176.     printf("   weight2^2 = %g, theta2^2 = %g\n", ...
  177.                weight2 * weight2', theta2' * theta2);
  178.   }
  179.   
  180.   plotdata = [((skiplimit + 1):niter)', wsave];
  181.   
  182.   plgrid();
  183.   pltitle ( "RLaB Neural Net Example (contributed)" );
  184.   xlabel ( "" );
  185.   ylabel ( "" );
  186.   plot(plotdata[nwsave-100:nwsave;]);
  187.   
  188. }
  189.  
  190.